-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Added OrbaxCheckpoint for keras 3.0 for Data centric saving and restore #21762
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Added OrbaxCheckpoint for keras 3.0 for Data centric saving and restore #21762
Conversation
…re Supports following feature - Asynchronous Checkpointing - Composite Checkpointing - Preservation Policies - Save Decision Policies - Transformations - Custom Handlers
Summary of ChangesHello @amitsrivastava78, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates Orbax checkpointing into Keras 3.0, providing a robust and flexible mechanism for saving and restoring training progress. The new Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces OrbaxCheckpoint, a new Keras callback for advanced checkpointing using the Orbax library. This is a significant feature addition that enables asynchronous saving, composite checkpoints, and other powerful capabilities. The implementation is extensive and is supported by a comprehensive suite of tests.
My review has identified several important issues that need attention. There are critical correctness and performance bugs in the main implementation: the batch-based saving logic is flawed, and the asynchronous saving feature is effectively disabled by blocking calls. Additionally, some features are incomplete, and there are minor areas for improvement in the tests to enhance maintainability. I have provided specific suggestions to address these points. After these fixes, this will be a very valuable addition to Keras.
| def __init__( | ||
| self, | ||
| directory, | ||
| monitor="val_loss", | ||
| verbose=0, | ||
| save_best_only=False, | ||
| mode="auto", | ||
| save_freq="epoch", | ||
| max_to_keep=5, | ||
| keep_period=None, | ||
| initial_value_threshold=None, | ||
| save_optimizer_state=True, | ||
| save_on_background=True, | ||
| save_metadata=None, | ||
| save_data_iterator=None, | ||
| save_metrics_state=False, | ||
| async_timeout_secs=600, | ||
| enable_background_delete=False, | ||
| post_finalization_callback=None, | ||
| save_transforms=None, | ||
| save_decision_policy=None, | ||
| save_interval=None, | ||
| ): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The __init__ method has 16 arguments, which is quite high. The Keras API design guidelines suggest reconsidering signatures with more than 6-7 arguments.1 While I understand the need to expose Orbax's functionality, it might be worth exploring if some of these could be grouped into a configuration object to improve readability and usability, similar to how ocp.CheckpointManagerOptions is used internally.
Style Guide References
Footnotes
-
The style guide recommends that functions with more than 6-7 arguments should be re-evaluated for simplification, possibly by breaking them into smaller objects or modular pieces. ↩
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #21762 +/- ##
==========================================
- Coverage 82.69% 82.65% -0.04%
==========================================
Files 573 578 +5
Lines 58888 59670 +782
Branches 9218 9374 +156
==========================================
+ Hits 48696 49319 +623
- Misses 7845 7929 +84
- Partials 2347 2422 +75
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
hertschuh
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR. This checkpointing system has a ton of features!
Quick first pass.
hertschuh
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple more comments I forgot.
- Remove conditional export decorator to ensure OrbaxCheckpoint is always available - Remove unnecessary exception handling in state tree operations - Update process index check comment for clarity - Format code to comply with 80-character line limit - Add distribution_lib modules for backend-specific distributed training support
- Remove unused 'result' variable in _reconstruct_state_tree_with_values - Fix long comment line in test file - Apply code formatting changes
…st handling - Implement OrbaxCheckpoint callback for async checkpointing with state tree handling - Add conditional exports for optional orbax-checkpoint dependency - Use pytest.importorskip for clean optional dependency testing - Ensure graceful handling when orbax-checkpoint is not installed
hertschuh
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The JAX implementation of def process_id() is missing.
General questions:
- Does this as-is support all backends?
- Does this support JAX sharding? I don't see anything related to sharing (which may be normal). What about re-sharding?
8097bd2 to
276ea9a
Compare
- Preserve nested state tree structures instead of flattening for better layer name preservation - Add backward compatibility for old flattened format checkpoints - Simplify test class by using self.get_temp_dir() instead of setUp/tearDown - Remove silent pytest.importorskip, add explicit skip conditions for backend-specific tests - Move process_id function from backend to distribution module - Update imports to use centralized LazyModule for orbax.checkpoint - Test across all backends (JAX, TensorFlow, PyTorch) - all passing
276ea9a to
b56dc7b
Compare
| checkpoints if there might be pending save operations. | ||
| """ | ||
| # Wait for any async operations to complete | ||
| while self.checkpointer.is_saving_in_progress(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably better to use checkpointer.wait() here, unless you want to log things periodicially.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am using Orbax's experimental v1 API, but the checkpointer.wait() method doesn't exist in our installed version. The GitHub link you provided might be from a development branch, i can see it here https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py#L521 but not on the installed released version.
621f566 to
eb7855d
Compare
…s expected failures Neural networks are inherently non-deterministic, so pipeline consistency checks should be skipped rather than fail. Added check_pipeline_consistency to EXPECTED_FAILED_CHECKS for all sklearn wrapper types.
163bc4b to
cd881dd
Compare
- Avoid unnecessary numpy conversion in _get_state_tree() for JAX backend - Preserve JAX arrays during saving instead of converting to numpy - Maintain cross-backend compatibility with proper loading conversions - Update async waiting to use CheckpointManager.wait_until_finished() - Implement AlwaysSavePolicy for reliable save decisions - Add expected failures for sklearn tests due to neural network non-determinism
c14c30e to
b7a0dff
Compare
- Preserve JAX arrays during saving when jax.monitoring.record_scalar is available - Fall back to numpy conversion for older JAX versions that don't have record_scalar - Maintain cross-backend compatibility while avoiding unnecessary conversions - Update async waiting to use CheckpointManager.wait_until_finished() - Implement AlwaysSavePolicy for reliable save decisions - Add expected failures for sklearn tests due to neural network non-determinism
fef84a0 to
33f4e66
Compare
Supports following feature